{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fragment NGram  \n",
    "2019.10.21  \n",
    "yoh Noguchi (edited by Stephen Wu on 2019.10.24)\n",
    "\n",
    "__Do not use too many data to begin with!!!__\n",
    "\n",
    "Goal of this script: Create new molecules by modifying substructures in a list of initial molecules based on a pre-trained fragment NGram generator.\n",
    "\n",
    "Steps:\n",
    "1. Prepare a list of initial molecules (in SMILES format)\n",
    "2. Fragmentation using RECAP in RDKit\n",
    "3. Keep bigger part as base structures, and extract only the smaller parts\n",
    "4. Modify the extracted small fragments using a pre-trained NGram\n",
    "5. Exhaustive combination of all new small fragments generated above with the base structures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from xenonpy.descriptor import Fingerprints\n",
    "import matplotlib.pyplot as plt\n",
    "from xenonpy.inverse.iqspr import GaussianLogLikelihood\n",
    "from xenonpy.inverse.iqspr import NGram\n",
    "from xenonpy.inverse.iqspr import IQSPR\n",
    "import csv\n",
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle as pk\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import Draw, Recap\n",
    "from rdkit.Chem.rdMolDescriptors import GetMorganFingerprint\n",
    "import os\n",
    "from xenonpy.descriptor import FrozenFeaturizer\n",
    "from xenonpy.descriptor.base import BaseFeaturizer\n",
    "from xenonpy.descriptor.base import BaseDescriptor\n",
    "from xenonpy.descriptor import RDKitFP, MACCS, ECFP, AtomPairFP, TopologicalTorsionFP, FCFP\n",
    "from bayes_opt import BayesianOptimization\n",
    "from rdkit.Chem import BRICS\n",
    "from collections import Counter\n",
    "from scipy.special import comb\n",
    "from joblib import Parallel, delayed\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### User parameters\n",
    "- `CSV_PATH`: directory of csv file to be imported\n",
    "- `NGRAM_PATH`: directory of saved NGram model\n",
    "- `FRAGMENT_LENGTH`: max length of SMILES to be considered a small fragment\n",
    "- `BASE_LENGTH`: max length of SMILES to be considered a base structure\n",
    "- `CREATED_FRAGMENTS_NUMBER`: number of modified fragments per each extracted small fragment\n",
    "※ `len(smis_Fragment)` * `CREATED_FRAGMENTS_NUMBER` = total of number of new fragments\n",
    "- `OUTPUT_FILENAME`: file name of output csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "CSV_PATH = 'XXXXX.csv'\n",
    "NGRAM_PATH = 'ngram_reorder_12_O20_peter.obj'\n",
    "FRAGMENT_LENGTH = 30\n",
    "BASE_LENGTH = 50\n",
    "CREATED_FRAGMENTS_NUMBER = 5\n",
    "OUTPUT_FILENAME = 'test.csv'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read in data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(CSV_PATH).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verify convertability of SMILES to RDKit MOL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "check = []\n",
    "for i, smiles in enumerate(data['SMILES']):\n",
    "    mol = Chem.MolFromSmiles(smiles)\n",
    "    if mol is None:\n",
    "        check.append(False)\n",
    "    if mol is not None:\n",
    "        if '.' in smiles:\n",
    "            check.append(False)\n",
    "        else:\n",
    "            check.append(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### List of SMILES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "smiles_list = list(data['SMILES'][check].unique())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fragmentation（RECAP）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "frag_list = []\n",
    "for i in range(len(smiles_list)):\n",
    "    mol_fu = (Chem.MolFromSmiles(smiles_list[i]))\n",
    "    decomp = Chem.Recap.RecapDecompose(mol_fu)\n",
    "    first_gen = [node.mol for node in decomp.children.values()]\n",
    "    for j in range(len(first_gen)):\n",
    "        smiles = Chem.MolToSmiles(first_gen[j])\n",
    "        frag_list.append(smiles)\n",
    "        \n",
    "frag_list = list(set(frag_list))    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read in NGram model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Ngram model  loding for Fragments\n",
    "with open(NGRAM_PATH, 'rb') as f:\n",
    "    n_gram = pk.load(f)\n",
    "\n",
    "setattr(n_gram,'min_len',2)\n",
    "n_gram.sample_order = (1, 20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Categorize fragments\n",
    "- `smis_Fragment`: small fragments\n",
    "- `smis_Base`: base structures\n",
    "- `smis_Large`: big structures to be filtered out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "smis_Fragment = []\n",
    "smis_Base = []\n",
    "smis_Large = []\n",
    "for smi in frag_list:\n",
    "    if len(smi) < FRAGMENT_LENGTH:\n",
    "        smis_Fragment.append(smi)\n",
    "    elif len(smi) < BASE_LENGTH:\n",
    "        smis_Base.append(smi)\n",
    "    else:\n",
    "        smis_Large.append(smi)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 【Function】combining fragment with base structure\n",
    "__\\*A B\\* -> BA__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combi_smiles(smis_frag, smis_base):\n",
    "    smis_frag = smis_frag\n",
    "    smis_base = smis_base\n",
    "\n",
    "    # prepare NGram object for use of ext. SMILES\n",
    "    from xenonpy.inverse.iqspr import NGram\n",
    "    ngram = NGram()\n",
    "\n",
    "    # check position of '*'\n",
    "    mols_base = Chem.MolFromSmiles(smis_base)\n",
    "    idx_base = [i for i in range(mols_base.GetNumAtoms()) if mols_base.GetAtomWithIdx(i).GetSymbol() == '*']\n",
    "\n",
    "    # rearrange base SMILES to avoid 1st char = '*'\n",
    "    if idx_base[0] == 0:\n",
    "        smis_base_head = Chem.MolToSmiles(mols_base,rootedAtAtom=1)\n",
    "    else:\n",
    "        smis_base_head = Chem.MolToSmiles(mols_base,rootedAtAtom=0)\n",
    "\n",
    "    # converge base to ext. SMILES and pick insertion location\n",
    "    esmi_base = ngram.smi2esmi(smis_base_head)\n",
    "    esmi_base = esmi_base[:-1]\n",
    "    idx_base = esmi_base.index[esmi_base['esmi'] == '*'].tolist()\n",
    "\n",
    "    # rearrange fragment to have 1st char = '*' and convert to ext. SMILES\n",
    "    mols_frag = Chem.MolFromSmiles(smis_frag)\n",
    "    idx_frag = [i for i in range(mols_frag.GetNumAtoms()) if mols_frag.GetAtomWithIdx(i).GetSymbol() == '*']\n",
    "    smis_frag_head = Chem.MolToSmiles(mols_frag,rootedAtAtom=idx_frag[0])\n",
    "    esmi_frag = ngram.smi2esmi(smis_frag_head)\n",
    "\n",
    "    # remove leading '*' and last '!'\n",
    "    esmi_frag = esmi_frag[1:-1]\n",
    "\n",
    "    # check open rings of base SMILES\n",
    "    nRing_base = esmi_base['n_ring'].loc[idx_base[0]]\n",
    "\n",
    "    # re-number rings in fragment SMILES\n",
    "    esmi_frag['n_ring'] = esmi_frag['n_ring'] + nRing_base\n",
    "\n",
    "    # delete '*' at the insertion location\n",
    "    esmi_base = esmi_base.drop(idx_base[0]).reset_index(drop=True)\n",
    "\n",
    "    # combine base with the fragment\n",
    "    ext_smi = pd.concat([esmi_base.iloc[:idx_base[0]], esmi_frag, esmi_base.iloc[idx_base[0]:]]).reset_index(drop=True)\n",
    "    new_pd_row = {'esmi': '!', 'n_br': 0, 'n_ring': 0, 'substr': ['!']}\n",
    "    ext_smi.append(new_pd_row, ignore_index=True)\n",
    "\n",
    "    fin_smi = ngram.esmi2smi(ext_smi)\n",
    "    mol_fin = Chem.MolFromSmiles(fin_smi)\n",
    "    #return mol_fin\n",
    "    return fin_smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 【Function】loop `combi_smiles` over all the base structures and list of fragments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loop_frag(base_list, frag_list):\n",
    "    comb_smi_list = []\n",
    "    for smi in tqdm(base_list):\n",
    "        results_list = Parallel(n_jobs=-1)([delayed(combi_smiles)(smi,s) for s in frag_list])\n",
    "        comb_smi_list.extend(results_list)\n",
    "    return comb_smi_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 【Function】Generating new fragments based on an initial fragment using pre-trained NGram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_Fragments(smi, N_frag, n_gram, max_iter = 100):\n",
    "    f_list = []\n",
    "    len_smi = len(smi)\n",
    "    num_min = int(len_smi/10)\n",
    "    num_max = (len_smi - 1)\n",
    "    n_gram.set_params(del_range=[num_min,num_max],max_len=1500, reorder_prob=0)\n",
    "    \n",
    "    for _ in range(max_iter):\n",
    "        smis_Ngram = n_gram.proposal([smi for _ in range(N_frag-len(f_list))])\n",
    "        f_list += [x for x in list(set(smis_Ngram)) if x.count('*') == 1]\n",
    "        if len(f_list) == N_frag:\n",
    "            break\n",
    "    \n",
    "    return f_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate new fragments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "Frangments_list = []\n",
    "escape_num = 0\n",
    "#smi_list.append(smis_Fragment[0])\n",
    "results_list = Parallel(n_jobs=-1)([delayed(create_Fragments)(s, CREATED_FRAGMENTS_NUMBER, n_gram) for s in smis_Fragment])\n",
    "\n",
    "for res in results_list:\n",
    "    Frangments_list.extend(res) \n",
    "#Frangments_list = list(set(Frangments_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Combine list of fragments with list of base structures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:03<00:00,  1.34it/s]\n"
     ]
    }
   ],
   "source": [
    "res = loop_frag(smis_Base, Frangments_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Output csv file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_result = pd.Series(res, name='SMILES')\n",
    "s_result.to_csv(OUTPUT_FILENAME, index=False, header='SMILES')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
